package it.uniroma2.tk;

import it.uniroma2.util.tree.Tree;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Vector;

/**
 * Implements the traditional tree kernel computation between {@link Tree} objects.
 *
 * <p>The kernel value is the sum, over every pair (node of the first tree, node of the second
 * tree), of a recursively defined similarity {@code delta}. {@code delta} is non-zero only where
 * both nodes expand with the same production.
 *
 * <p>The settings {@link #lambda} and {@link #lexicalized} are shared mutable statics and are
 * read during every call. All memoization state is local to a single
 * {@link #value(Tree, Tree)} call. Concurrent calls therefore do not share caches, and the
 * compared trees are not retained once the call returns.
 *
 * @author lorenzo
 */
public class TreeKernel {

	/** Multiplier applied to each child's delta in the product {@code (1 + lambda * delta(child))}. */
	public static double lambda = 1;

	/**
	 * Controls how two pre-terminal nodes are scored. A pre-terminal node has exactly one child,
	 * and that child is terminal. When true, such a pair scores 1 if {@code a.equals(b)}. When
	 * false, such a pair always scores 0.
	 */
	public static boolean lexicalized = false;

	/**
	 * Computes the tree kernel between two trees.
	 *
	 * @param a first tree (must not be null)
	 * @param b second tree (must not be null)
	 * @return the sum of {@code delta} over every (node of {@code a}, node of {@code b}) pair
	 * @throws NullPointerException if either tree is null
	 */
	public static double value(Tree a, Tree b) {
		List<Tree> nodesA = allNodes(a);
		// Hoisted: the node list of b is the same for every node of a.
		List<Tree> nodesB = allNodes(b);
		// Memo keyed by node identity, so caching does not depend on Tree.equals/hashCode.
		// It is local to this call, so parallel calls cannot interfere and nothing outlives
		// the call.
		Map<Tree, Map<Tree, Double>> memo = new IdentityHashMap<Tree, Map<Tree, Double>>();
		double sum = 0;
		for (Tree aa : nodesA) {
			for (Tree bb : nodesB) {
				sum += delta(aa, bb, memo);
			}
		}
		return sum;
	}

	/**
	 * Recursive node-pair similarity.
	 *
	 * <ul>
	 * <li>The result is 0 if the nodes have a different number of children.</li>
	 * <li>If both nodes are pre-terminal (one terminal child each), the result is 1 when
	 * {@link #lexicalized} is set and {@code a.equals(b)}, otherwise 0.</li>
	 * <li>Otherwise, if {@link #productionCompare} holds, the result is the product over aligned
	 * children of {@code 1 + lambda * delta(childA, childB)}.</li>
	 * <li>In every other case the result is 0.</li>
	 * </ul>
	 *
	 * @param a node from the first tree
	 * @param b node from the second tree
	 * @param memo per-call cache of already computed (a, b) values
	 * @return the delta value for the pair
	 */
	private static double delta(Tree a, Tree b, Map<Tree, Map<Tree, Double>> memo) {
		Map<Tree, Double> row = memo.get(a);
		if (row == null) {
			row = new IdentityHashMap<Tree, Double>();
			memo.put(a, row);
		}
		Double cached = row.get(b);
		if (cached != null) {
			return cached.doubleValue();
		}

		double k = 0;
		if (a.getChildren().size() == b.getChildren().size()) {
			if (a.getChildren().size() == 1 && a.getChildren().get(0).isTerminal() && b.getChildren().get(0).isTerminal()) {
				// Pre-terminal pair: it only scores when lexical items are taken into account.
				// NOTE(review): this relies on Tree.equals. Confirm that it compares the label and
				// the word. If Tree does not override equals, only the same node instance matches.
				if (lexicalized && a.equals(b)) {
					k = 1;
				}
			} else if (productionCompare(a, b)) {
				k = 1;
				for (int i = 0; i < a.getChildren().size(); i++) {
					k *= 1 + lambda * delta(a.getChildren().get(i), b.getChildren().get(i), memo);
				}
			}
		}
		row.put(b, k);
		return k;
	}

	/**
	 * Tests whether two nodes expand with the same production. This requires all three of:
	 * <ul>
	 * <li>equal root labels;</li>
	 * <li>the same, non-zero number of children;</li>
	 * <li>pairwise-equal child root labels, in order.</li>
	 * </ul>
	 *
	 * @param a first node
	 * @param b second node
	 * @return true if the productions match
	 */
	private static boolean productionCompare(Tree a, Tree b) {
		if (!a.getRootLabel().equals(b.getRootLabel()))
			return false;
		if (a.getChildren().size() != b.getChildren().size() || a.getChildren().size() == 0)
			return false;
		for (int i = 0; i < a.getChildren().size(); i++)
			if (!a.getChildren().get(i).getRootLabel().equals(b.getChildren().get(i).getRootLabel()))
				return false;
		return true;
	}

	/**
	 * Lists every node of a tree in pre-order (node first, then its children left to right).
	 *
	 * @param root the tree root
	 * @return a new list containing all nodes of the tree
	 */
	private static List<Tree> allNodes(Tree root) {
		List<Tree> nodes = new ArrayList<Tree>();
		collectNodes(root, nodes);
		return nodes;
	}

	/** Appends {@code node} and then all of its descendants (pre-order) to {@code out}. */
	private static void collectNodes(Tree node, List<Tree> out) {
		out.add(node);
		for (Tree child : node.getChildren()) {
			collectNodes(child, out);
		}
	}

}
